# coding: utf-8


import tensorflow as tf
import numpy as np
from math import ceil
import os
import shutil
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt


def get_image_list(file_dir):
    """Return every entry of ``file_dir`` as a full path, in random order.

    No filtering is applied: every name reported by ``os.listdir`` is
    included. The order is randomized with NumPy's global RNG. As a side
    effect, the number of entries and the first ten paths are printed.

    Args:
        file_dir: directory to scan.

    Returns:
        list of str: ``os.path.join(file_dir, name)`` for each entry, shuffled.
    """
    names = os.listdir(file_dir)
    # Report how many entries were found.
    print('len(image_list):', len(names))
    # Shuffle in place using the global np.random state.
    np.random.shuffle(names)
    paths = [os.path.join(file_dir, name) for name in names]
    print(paths[:10])
    return paths


def get_batch(image_list, image_W, image_H, batch_size, capacity, num_threads=8):
    """Build a TF1 queue-based pipeline yielding batches of grayscale images.

    Each file is read, decoded with ``tf.image.decode_jpeg`` to a single
    channel, center-cropped or zero-padded to ``image_H`` x ``image_W``, and
    cast to float32. Pixel values are NOT normalized. The caller must start
    queue runners (``tf.train.start_queue_runners``) before evaluating the
    returned tensor.

    Args:
        image_list: list of image file paths.
        image_W: target image width in pixels.
        image_H: target image height in pixels.
        batch_size: number of images per batch.
        capacity: capacity of both the filename queue and the batch queue.
        num_threads: number of threads enqueueing into the batch queue
            (default 8, the previously hard-coded value).

    Returns:
        float32 tensor of shape (batch_size, image_H, image_W, 1).
    """
    image_list = tf.cast(image_list, tf.string)

    # Filename queue: cycles forever (num_epochs=None) in shuffled order.
    input_queue = tf.train.slice_input_producer([image_list],
                                                num_epochs=None,
                                                shuffle=True,
                                                seed=None,
                                                capacity=capacity)
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=1)
    # resize_image_with_crop_or_pad takes (image, target_height, target_width);
    # height must come first, otherwise non-square targets are transposed.
    image = tf.image.resize_image_with_crop_or_pad(image, image_H, image_W)
    image = tf.cast(image, tf.float32)
    # Normalization (e.g. image / 255 or tf.image.per_image_standardization)
    # is deliberately not applied here; callers receive raw pixel values.
    image_batch = tf.train.batch([image],
                                 batch_size=batch_size,
                                 num_threads=num_threads,
                                 capacity=capacity)
    return image_batch


def test():
    """Smoke-test the input pipeline by dumping one batch to PNG files.

    Builds the pipeline over the images in ``../imgs/normal_img/``, pulls a
    single batch, and saves each image of that batch as a grayscale plot to
    ``./input_data_test/<n>.png`` with n = 1..BATCH_SIZE.
    """
    BATCH_SIZE = 64
    CAPACITY = 256  # queue capacity
    IMG_W = 512  # image width
    IMG_H = 512  # image height
    NORMAL_DIR = '../imgs/normal_img/'  # directory holding the normal images

    image_list = get_image_list(NORMAL_DIR)
    image_batch = get_batch(image_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

    # Create the output directory if it does not exist yet.
    if not os.path.exists('./input_data_test'):
        os.makedirs('./input_data_test')

    with tf.Session() as sess:
        # The coordinator lets us stop the queue-runner threads cleanly.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Only a single batch is evaluated for this test.
            if not coord.should_stop():
                batch = sess.run(image_batch)
                print(np.array(batch).shape)
                for img_num, image in enumerate(batch, start=1):
                    # Drop the trailing channel axis so imshow gets a 2-D array.
                    image = np.squeeze(image)
                    fig = plt.figure()
                    plt.imshow(image, cmap='gray')
                    plt.grid(False)
                    plt.savefig('./input_data_test/{}.png'.format(img_num))
                    # Close each figure; otherwise all BATCH_SIZE figures stay
                    # alive in pyplot's figure manager.
                    plt.close(fig)
        except tf.errors.OutOfRangeError:
            print("except tf.errors.OutOfRangeError.")
        finally:
            coord.request_stop()  # ask all queue-runner threads to stop
        coord.join(threads)  # wait for the threads to finish


if __name__ == "__main__":
    # Executed directly (not imported): run the pipeline smoke test.
    test()
